from tkinter import filedialog, Tk
import os
import gradio as gr


def base_ui():
    """Lay out the model/dataset pickers and the output location/name fields.

    The first row holds the base-model and dataset selectors. The second row
    holds the output folder and output name.

    Returns:
        A 4-tuple of components: (base model textbox, dataset textbox,
        output folder textbox, output name textbox).
    """
    with gr.Row():
        model_box = base_model_ui()
        dataset_box = train_dataset_ui()
    with gr.Row():
        out_dir_box = output_model_folder_ui()
        out_name_box = output_model_name_ui()
    components = (model_box, dataset_box, out_dir_box, out_name_box)
    return components


def base_model_ui():
    """Create the base-model path textbox plus a button that opens a file picker.

    Clicking the button passes the textbox's current value to
    ``get_any_file_path`` and writes the returned path back into the textbox.

    Returns:
        The base-model path textbox component.
    """
    textbox = gr.Textbox(label="基础模型", placeholder="请选择训练的基础模型")
    picker = gr.Button(
        value="",
        variant="primary",
        icon="./theme/base_model_icon.svg",
        elem_id="base_model_name_button",
    )
    # The textbox is both input (dialog start location) and output (chosen path).
    picker.click(get_any_file_path, inputs=textbox, outputs=textbox, show_progress=False)
    return textbox


def train_dataset_ui():
    """Create the dataset-folder textbox plus a button that opens a folder picker.

    Clicking the button passes the textbox's current value to
    ``get_folder_path`` and writes the returned folder back into the textbox.

    Returns:
        The dataset folder textbox component.
    """
    textbox = gr.Textbox(label="数据集", placeholder="请选择训练的数据集所在的文件夹")
    picker = gr.Button(
        value="",
        variant="primary",
        icon="./theme/dataset_icon.svg",
        elem_id="train_dataset_path_button",
    )
    # The textbox is both input (dialog start location) and output (chosen folder).
    picker.click(get_folder_path, inputs=textbox, outputs=textbox, show_progress=False)
    return textbox


def output_model_folder_ui():
    """Create the LoRA output-folder textbox plus a button that opens a folder picker.

    Clicking the button passes the textbox's current value to
    ``get_folder_path`` and writes the returned folder back into the textbox.

    Returns:
        The output folder textbox component.
    """
    textbox = gr.Textbox(
        label="LORA输出目录",
        placeholder="请选择训练好的LORA要存放的文件夹",
        elem_id="output_model_path",
    )
    picker = gr.Button(
        value="",
        variant="primary",
        icon="./theme/output_model_folder_icon.svg",
        elem_id="output_model_path_button",
    )
    # The textbox is both input (dialog start location) and output (chosen folder).
    picker.click(get_folder_path, inputs=textbox, outputs=textbox, show_progress=False)
    return textbox


def output_model_name_ui():
    """Create the free-text field for the name under which the LoRA is saved.

    Returns:
        The output name textbox component.
    """
    return gr.Textbox(label="LORA名字", placeholder="请输入要保存的LORA名字")


def train_param_ui():
    """Build the training-hyperparameter controls in two rows.

    The first row holds batch size and learning rate sliders. The second row
    holds the epoch count and the save interval (in epochs).

    Returns:
        A 4-tuple of components:
        (batch_size, learning_rate, epoch, save_model_every_n_epochs).
    """
    with gr.Row():
        batch_size = gr.Slider(
            label="批次大小",
            value=4,
            minimum=1,
            maximum=64,
            step=1,
            interactive=True
        )
        learning_rate = gr.Slider(
            label="学习率",
            value=0.0001,
            minimum=0.0001,
            maximum=1,
            step=0.0001,
            interactive=True
        )
    with gr.Row():
        epoch = gr.Number(
            label="迭代次数",
            value=4,
            precision=0,
            # At least one epoch is required. This matches the lower bound
            # already enforced on the save interval below.
            minimum=1,
            interactive=True
        )
        save_model_every_n_epochs = gr.Number(
            label='迭代指定次数后保存一次训练的LORA',
            value=1,
            precision=0,
            minimum=1,
            interactive=True
        )
    # Return after both rows are closed, so the layout context is clearly complete.
    return batch_size, learning_rate, epoch, save_model_every_n_epochs


def get_folder_path(folder_path=''):
    """Open a native folder-selection dialog and return the chosen folder.

    The dialog starts in the directory component of *folder_path*, as split
    by ``get_dir_and_file``. If the user cancels, the incoming value is
    returned so the caller's field keeps its previous content.

    Args:
        folder_path: Current folder value. ``None`` is treated as ``''``
            (presumably what an empty Gradio textbox may send — confirm).

    Returns:
        The selected folder path, or the (normalized) incoming value on cancel.
    """
    # os.path.split (used by get_dir_and_file) raises TypeError on None.
    current_folder_path = '' if folder_path is None else folder_path
    initial_dir, _ = get_dir_and_file(current_folder_path)

    root = Tk()
    try:
        # Keep the dialog above the browser window and hide the empty Tk root.
        root.wm_attributes('-topmost', 1)
        root.withdraw()
        chosen = filedialog.askdirectory(initialdir=initial_dir)
    finally:
        # Always release the Tk root, even if the dialog raises.
        root.destroy()

    # Cancel yields '' (and on some Tk builds an empty tuple); keep the old value.
    if not chosen:
        return current_folder_path
    return chosen


def get_any_file_path(file_path=''):
    """Open a native file-selection dialog and return the chosen file.

    The dialog starts in the directory of *file_path* with its file name
    pre-filled, as split by ``get_dir_and_file``. If the user cancels, the
    incoming value is returned so the caller's field keeps its previous content.

    Args:
        file_path: Current file value. ``None`` is treated as ``''``
            (presumably what an empty Gradio textbox may send — confirm).

    Returns:
        The selected file path, or the (normalized) incoming value on cancel.
    """
    # os.path.split (used by get_dir_and_file) raises TypeError on None.
    current_file_path = '' if file_path is None else file_path
    initial_dir, initial_file = get_dir_and_file(current_file_path)

    root = Tk()
    try:
        # Keep the dialog above the browser window and hide the empty Tk root.
        root.wm_attributes('-topmost', 1)
        root.withdraw()
        chosen = filedialog.askopenfilename(
            initialdir=initial_dir,
            initialfile=initial_file,
        )
    finally:
        # Always release the Tk root, even if the dialog raises.
        root.destroy()

    # Cancel yields '' (and on some Tk builds an empty tuple); keep the old value.
    if not chosen:
        return current_file_path
    return chosen


def get_dir_and_file(file_path):
    """Split *file_path* into its directory and final component.

    Args:
        file_path: Path to split. ``None`` is treated as an empty path.

    Returns:
        A ``(dir_path, file_name)`` tuple as produced by ``os.path.split``.
        An empty or ``None`` path gives ``('', '')``.
    """
    if file_path is None:
        # os.path.split raises TypeError on None; treat it as "no path yet".
        file_path = ''
    dir_path, file_name = os.path.split(file_path)
    return dir_path, file_name
